{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports and helper functions\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import webdataset as wds\n",
    "import braceexpand\n",
    "import tempfile\n",
    "import glob\n",
    "from itertools import islice\n",
    "import random\n",
    "\n",
    "def summarize(sample):\n",
    "    \"\"\"Print one line per field of `sample`: its key and the first 100 characters of repr(value).\"\"\"\n",
    "    for field, value in sample.items():\n",
    "        preview = repr(value)[:100]\n",
    "        print(field, preview)\n",
    "\n",
    "def read_binary(fname):\n",
    "    \"\"\"Return the full contents of the file at `fname` as bytes.\"\"\"\n",
    "    with open(fname, \"rb\") as handle:\n",
    "        data = handle.read()\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parallel Processing of Shards: Large Scale OCR\n",
    "\n",
    "This notebook illustrates how to take a large collection of shards consisting of PDFs and process them using `pdftoppm` and `tesseract` into a new dataset consisting of page images and corresponding OCR output.\n",
    "\n",
    "The general approach is to process the samples within each shard sequentially, and to process multiple shards in parallel. The basic structure of such a job looks like:\n",
    "\n",
    "```Python\n",
    "with WebDataset(srcname) as src:\n",
    "    with TarWriter(dstname) as dst:\n",
    "        for sample in src:\n",
    "            ... do something with sample ...\n",
    "            dst.write(sample)\n",
    "upload(dstname)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Arxiv Dataset of PDFs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dataset is tar files containing PDFs, each using the Arxiv naming convention.\n",
    "# Peek at the first few member names of the first shard. A single shard is named\n",
    "# explicitly: `{a..b}` brace ranges are a bash feature that the plain /bin/sh used\n",
    "# for `!` commands may not expand, and `tar tf` stops at the end of the first\n",
    "# archive in the stream anyway.\n",
    "\n",
    "!gsutil cat gs://webdataset/testdata/arxiv-pdfs-000000.tar | tar tf - | sed 5q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arxiv naming conventions are incompatible with WebDataset, but we can add\n",
    "# a file renaming function to the WebDataset to fix this.\n",
    "\n",
    "def arxiv_rename(name):\n",
    "    \"\"\"Rename an Arxiv tar member so WebDataset groups it into one sample.\n",
    "\n",
    "    WebDataset uses the part of a file name before its first dot as the\n",
    "    sample key, so `2101.00001.pdf` would get the key `2101` and the\n",
    "    extension `00001.pdf`. This keeps the final extension and replaces the\n",
    "    remaining dots with underscores: `2101.00001.pdf` becomes `2101_00001.pdf`.\n",
    "    Non-PDF members keep their own extension (`a.b.txt` becomes `a_b.txt`).\n",
    "    \"\"\"\n",
    "    stem, ext = os.path.splitext(name)\n",
    "    return stem.replace(\".\", \"_\") + ext\n",
    "\n",
    "# To keep this example small we use only two shards; a real job would\n",
    "# usually have hundreds or thousands of them.\n",
    "\n",
    "dataset = \"gs://webdataset/testdata/arxiv-pdfs-{000000..000001}.tar\"\n",
    "\n",
    "# Expand the brace pattern into explicit shard URLs, open the dataset with\n",
    "# arxiv_rename applied to member names, and show the first sample.\n",
    "\n",
    "shardurls = [url for url in braceexpand.braceexpand(dataset)]\n",
    "ds = wds.WebDataset(shardurls, rename_files=arxiv_rename)\n",
    "sample = next(iter(ds))\n",
    "summarize(sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Running Tesseract on a Single PDF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_sample(sample, maxpages=9999, shuffle=True):\n",
    "    \"\"\"Process a sample from the Arxiv dataset.\n",
    "\n",
    "    This function converts the PDF file to a sequence of JPEG images\n",
    "    and then invokes Tesseract to recognize the text in the images.\n",
    "    It yields a sequence of samples, one per page, each containing\n",
    "    the JPEG image and the hOCR output from Tesseract.\n",
    "\n",
    "    Args:\n",
    "        sample: dict with a \"__key__\" string and the raw PDF bytes under \"pdf\".\n",
    "        maxpages: maximum number of pages to OCR (and yield) for this PDF.\n",
    "        shuffle: if True (the default), pages are visited in random order via\n",
    "            `random.shuffle`, so a small `maxpages` picks a random subset of\n",
    "            pages; seed `random` for reproducible output.\n",
    "\n",
    "    Yields:\n",
    "        dict with \"__key__\" (the PDF's key, \"/\", and the page image's base name),\n",
    "        \"jpg\" (page image bytes) and \"hocr\" (Tesseract hOCR output bytes).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if pdftoppm or tesseract exits with a non-zero status.\n",
    "    \"\"\"\n",
    "    import shlex  # quotes the temp paths interpolated into the shell commands below\n",
    "\n",
    "    # We work in a temporary directory; most operations are command line tools.\n",
    "    # Since this is a generator, the directory is removed only once the generator\n",
    "    # is exhausted or closed (including when it is garbage-collected).\n",
    "    with tempfile.TemporaryDirectory() as dirname:\n",
    "\n",
    "        # Write the PDF file to disk and convert it to a sequence of JPEGs using pdftoppm.\n",
    "        # -forcenum numbers the output even for single-page PDFs, so the glob below matches.\n",
    "        pdfpath = os.path.join(dirname, \"sample.pdf\")\n",
    "        with open(pdfpath, \"wb\") as stream:\n",
    "            stream.write(sample[\"pdf\"])\n",
    "        status = os.system(f\"(cd {shlex.quote(dirname)} && pdftoppm -forcenum -jpeg -r 300 -l 9999 sample.pdf page)\")\n",
    "        assert status == 0, f\"pdftoppm failed with status {status}\"\n",
    "\n",
    "        # Next, collect the page images (sorted, then optionally shuffled) and OCR them.\n",
    "        pages = sorted(glob.glob(os.path.join(glob.escape(dirname), \"page-*.jpg\")))\n",
    "        if shuffle:\n",
    "            random.shuffle(pages)\n",
    "\n",
    "        for page in islice(pages, maxpages):\n",
    "            page_without_suffix = page[:-4]  # strip the \".jpg\" suffix\n",
    "            base = os.path.basename(page_without_suffix)\n",
    "\n",
    "            # Invoke Tesseract to convert the page image to hOCR. Check its exit\n",
    "            # status so a failure is reported here, not as a missing .hocr file below.\n",
    "            status = os.system(f\"tesseract {shlex.quote(page)} {shlex.quote(page_without_suffix)} hocr\")\n",
    "            assert status == 0, f\"tesseract failed on {base} with status {status}\"\n",
    "\n",
    "            # Construct the output sample.\n",
    "            nsample = {\n",
    "                \"__key__\": sample[\"__key__\"] + f\"/{base}\",\n",
    "                \"jpg\": read_binary(page_without_suffix + \".jpg\"),\n",
    "                \"hocr\": read_binary(page_without_suffix + \".hocr\"),\n",
    "            }\n",
    "\n",
    "            # Yield one sample per recognized page.\n",
    "            yield nsample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = next(process_sample(sample))\n",
    "summarize(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Processing a Shard of PDF Files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_shard(src, dst, maxpdfs=999999, maxpages=9999):\n",
    "    \"\"\"Turn one shard of Arxiv PDFs into a shard of per-page OCR samples.\n",
    "\n",
    "    The PDFs in `src` are read with arxiv_rename applied to member names,\n",
    "    each one is passed through process_sample, and every resulting page\n",
    "    sample (JPEG image plus hOCR output) is written to a new tar shard at `dst`.\n",
    "\n",
    "    `maxpdfs` caps how many PDFs are read from the shard and `maxpages` caps\n",
    "    how many pages are kept per PDF. Small values are useful for testing, and\n",
    "    `maxpages` also bounds the output for very long documents.\n",
    "    \"\"\"\n",
    "    with wds.TarWriter(dst) as sink:\n",
    "        pdf_samples = wds.WebDataset(src, rename_files=arxiv_rename)\n",
    "        for pdf_sample in islice(pdf_samples, maxpdfs):\n",
    "            print(pdf_sample[\"__key__\"], pdf_sample.keys())\n",
    "            for page_sample in process_sample(pdf_sample, maxpages=maxpages):\n",
    "                print(\"    \", page_sample[\"__key__\"])\n",
    "                sink.write(page_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: process the first two PDFs of the first shard into output.tar,\n",
    "# keeping two pages of each. process_sample shuffles pages by default, so which\n",
    "# two pages are kept can differ between runs.\n",
    "# Remove any output left over from a previous run first.\n",
    "!rm -f output.tar\n",
    "process_shard(shardurls[0], \"output.tar\", maxpdfs=2, maxpages=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the contents of the shard written above (page images and hOCR files).\n",
    "!tar tvf output.tar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parallelizing Processing with Ray\n",
    "\n",
    "This illustrates how to use Ray to process many shards in parallel.\n",
    "\n",
    "You don't need to use Ray for this; you can also invoke `process_shard` in parallel using a job queueing system or some other distributed computing framework.\n",
    "\n",
    "Generally, it is easiest to process the samples within each shard sequentially and to process multiple shards in parallel. If needed, you can additionally parallelize the processing of samples within a shard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ray\n",
    "\n",
    "# Demo settings: process only two PDFs per shard and two pages per PDF.\n",
    "maxpdfs = 2\n",
    "maxpages = 2\n",
    "# The `echo` prefix only prints the upload command, so no shard is actually uploaded.\n",
    "upload_cmd = \"echo gsutil cp {src} {dst}\"\n",
    "\n",
    "# Start Ray only once, so re-running this cell doesn't try to initialize it again.\n",
    "if not ray.is_initialized():\n",
    "    ray.init()\n",
    "\n",
    "@ray.remote(num_cpus=4)\n",
    "def process_shard_parallel(src, dstbucket, maxpdfs=999999, maxpages=9999):\n",
    "    \"\"\"Ray task that OCRs one Arxiv shard and uploads the resulting shard.\n",
    "\n",
    "    The shard at `src` is processed with process_shard into a local temporary\n",
    "    file, which is then handed to the module-level `upload_cmd` (formatted with\n",
    "    `src` = that temporary file and `dst` = `dstbucket` + \"/\" + the base name\n",
    "    of `src`). The temporary file is deleted when the `with` block exits.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the upload command exits with a non-zero status.\n",
    "    \"\"\"\n",
    "    target = \"/\".join((dstbucket, os.path.basename(src)))\n",
    "    with tempfile.NamedTemporaryFile() as scratch:\n",
    "        process_shard(src, scratch.name, maxpdfs=maxpdfs, maxpages=maxpages)\n",
    "        upload_status = os.system(upload_cmd.format(src=scratch.name, dst=target))\n",
    "        assert upload_status == 0\n",
    "\n",
    "# Launch one Ray task per input shard and wait for all of them to finish.\n",
    "ray.get([process_shard_parallel.remote(src, \"gs://somebucket\", maxpdfs=maxpdfs, maxpages=maxpages) for src in shardurls])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
